{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FtOy24YspTE8"
      },
      "source": [
        "<span style=\"color: red;\">Requirement when running in Google Colab</span>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vKL4GQlKpTE9"
      },
      "outputs": [],
      "source": [
        "!pip install diffusers"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XeeymRQ8pTE-"
      },
      "source": [
        "# Chapter 3 - Classifier-Free Guidance\n",
        "\n",
        "Classifier-Free Guidance (CFG) is a technique that significantly improves the quality and controllability of image generation in diffusion models such as Stable Diffusion. Introduced by Ho and Salimans (https://arxiv.org/pdf/2207.12598) in 2021, CFG addresses the limitations of sampling without guidance, which often produces low-quality results that only loosely follow the prompt. By combining the unconditional and text-conditioned noise predictions, and pushing the result further towards the text-conditioned one, CFG aligns the generated image more closely with the input prompt. This improves image quality, increases prompt relevance, and gives users more control over the generation process. In practice, almost every application of Stable Diffusion relies on CFG to produce high-quality images that adhere to the prompt.\n",
        "\n",
        "Now we move on to the implementation, which is quite simple. The first part below is copied unchanged from Chapter 2, so you can run it and move on to the next cell:\n"
      ]
    },
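    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a compact reference, the guidance step used later in this chapter can be written as a single formula. The symbols are notation assumed here for illustration rather than taken from the paper: $\\epsilon_\\theta(z_t, \\varnothing)$ is the noise predicted for the empty prompt, $\\epsilon_\\theta(z_t, c)$ is the noise predicted for the text prompt $c$, and $s$ corresponds to `guidance_scale` in the code.\n",
        "\n",
        "$$\\hat{\\epsilon} = \\epsilon_\\theta(z_t, \\varnothing) + s \\left( \\epsilon_\\theta(z_t, c) - \\epsilon_\\theta(z_t, \\varnothing) \\right)$$\n",
        "\n",
        "With $s = 0$ this reduces to the unconditional prediction, and with $s = 1$ to the plain text-conditioned prediction. Values above 1, such as the 7.5 used below, push the result past the conditional prediction, which is what strengthens prompt adherence."
      ]
    },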
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VC5HbehNpTE_"
      },
      "outputs": [],
      "source": [
        "import warnings\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "import diffusers\n",
        "from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler\n",
        "from transformers import CLIPTextModel, CLIPTokenizer\n",
        "import torch\n",
        "import matplotlib.pyplot as plt\n",
        "from tqdm import tqdm\n",
        "\n",
        "model_id = \"stabilityai/stable-diffusion-2-1-base\"\n",
        "\n",
        "vae = AutoencoderKL.from_pretrained(\n",
        "    model_id, subfolder=\"vae\", revision=None, variant=\"fp16\"\n",
        ").to(\"cuda\")\n",
        "\n",
        "\n",
        "unet = UNet2DConditionModel.from_pretrained(\n",
        "    model_id, subfolder=\"unet\", revision=None, variant=\"fp16\"\n",
        ").to(\"cuda\")\n",
        "scheduler = DDIMScheduler.from_pretrained(model_id, subfolder=\"scheduler\")\n",
        "\n",
        "\n",
        "tokenizer = CLIPTokenizer.from_pretrained(\n",
        "    model_id, subfolder=\"tokenizer\", revision=None, variant=\"fp16\"\n",
        ")\n",
        "text_encoder = CLIPTextModel.from_pretrained(\n",
        "    model_id, subfolder=\"text_encoder\", revision=None, variant=\"fp16\"\n",
        ").to(\"cuda\")\n",
        "\n",
        "\n",
        "prompt = \"A photo of a woman, straight hair, light blonde and pink hair, smiling expression, grey background\"\n",
        "\n",
        "text_inputs = tokenizer(\n",
        "                prompt,\n",
        "                padding=\"max_length\",\n",
        "                max_length=tokenizer.model_max_length,\n",
        "                truncation=True,\n",
        "                return_tensors=\"pt\",\n",
        "            ).input_ids.to(\"cuda\")\n",
        "prompt_embeds = text_encoder(text_inputs)[0]\n",
        "\n",
        "latents = torch.randn((1, unet.config.in_channels, unet.config.sample_size, unet.config.sample_size), generator=torch.Generator().manual_seed(220)).to(\"cuda\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-wbWg9vlpTE_"
      },
      "source": [
        "CFG requires an unconditional prediction from the model, which we use at every sampling step to steer the noise prediction. We therefore tokenise an empty prompt as our unconditional input and encode it with the text encoder."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6HIwQTm6pTFA"
      },
      "outputs": [],
      "source": [
        "uncond_tokens = \"\"\n",
        "max_length = prompt_embeds.shape[1]\n",
        "uncond_input = tokenizer(\n",
        "    uncond_tokens,\n",
        "    padding=\"max_length\",\n",
        "    max_length=max_length,\n",
        "    truncation=True,\n",
        "    return_tensors=\"pt\").input_ids.to(\"cuda\")\n",
        "uncond_prompt_embeds = text_encoder(uncond_input)[0]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VmVDyz6tpTFA"
      },
      "source": [
        "Now we concatenate the unconditional and conditional prompt embeddings into a single batch, so that each sampling step runs inference on both at once."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vF5z3AL_pTFA"
      },
      "outputs": [],
      "source": [
        "prompt_embeds_combined = torch.cat([uncond_prompt_embeds, prompt_embeds])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hPNHpzoEpTFB"
      },
      "source": [
        "The core difference from the previous chapter is that we duplicate our latents, so each step runs inference on a batch of two copies of the same latent, one per prompt condition:\n",
        "\n",
        "`latent_model_input = torch.cat([latents] * 2)`\n",
        "\n",
        "Based on the noise predictions for the unconditional and conditional inputs, we compute a guided noise prediction, which the scheduler then uses to produce the latent for the next step of the sampling process."
      ]
    },
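    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To get a feel for what the guidance scale does, here is a minimal sketch of the combination step on toy scalar values. The helper name `apply_cfg` and the numbers are made up for illustration; in the sampling loop the same arithmetic is applied elementwise to the noise tensors.\n",
        "\n",
        "```python\n",
        "def apply_cfg(noise_uncond, noise_text, guidance_scale):\n",
        "    # Start from the unconditional prediction and move along the\n",
        "    # direction that points towards the text-conditioned prediction.\n",
        "    return noise_uncond + guidance_scale * (noise_text - noise_uncond)\n",
        "\n",
        "\n",
        "# Toy scalars standing in for a single element of the noise tensors.\n",
        "noise_uncond, noise_text = 0.25, 0.75\n",
        "\n",
        "for scale in (0.0, 1.0, 7.5):\n",
        "    print(scale, apply_cfg(noise_uncond, noise_text, scale))\n",
        "```\n",
        "\n",
        "A scale of 0 returns the unconditional value, a scale of 1 returns the text-conditioned value, and larger scales overshoot past it in the same direction."
      ]
    },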
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gjyh57rKpTFB"
      },
      "outputs": [],
      "source": [
        "num_inference_steps = 50\n",
        "guidance_scale = 7.5\n",
        "scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n",
        "timesteps = scheduler.timesteps\n",
        "\n",
        "with torch.no_grad():\n",
        "    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=\"Inference steps\"):\n",
        "\n",
        "        latent_model_input = torch.cat([latents] * 2)\n",
        "\n",
        "        noise_pred = unet(\n",
        "            latent_model_input,\n",
        "            t,\n",
        "            encoder_hidden_states=prompt_embeds_combined,\n",
        "            cross_attention_kwargs=None,\n",
        "            return_dict=False,\n",
        "        )[0]\n",
        "\n",
        "\n",
        "        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
        "        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
        "\n",
        "        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cQEle_1epTFB"
      },
      "source": [
        "As before, we decode the latents with the VAE to bring them into pixel space. We should now see a significantly better result."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DYNXzbtHpTFC"
      },
      "outputs": [],
      "source": [
        "with torch.no_grad():\n",
        "    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]\n",
        "    image_np = image.squeeze(0).float().permute(1,2,0).detach().cpu()\n",
        "    image_np = image_np - image_np.min()\n",
        "    image_np = image_np / image_np.max()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "whUJQ0Q5pTFC"
      },
      "outputs": [],
      "source": [
        "plt.imshow(image_np)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.11"
    },
    "colab": {
      "provenance": [],
      "gpuType": "L4"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}